import os
import numpy as np
import argparse
import time
import torch
import model.config as cfg

from torch import optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from model.model import Yolov2
from model.loss import YoloLoss
from model.dataset import VOCDataset, detection_collate
from model.utils.train import adjust_learning_rate


def train(args):
    """Train a Yolov2 detector on a VOC-style dataset.

    Args:
        args: parsed command-line namespace (see ``parse_args``). Note that
            lr / decay schedule / weight decay / momentum / batch size are
            overwritten below from ``model.config`` and are NOT taken from
            the command line.
    """
    # define the hyper parameters first (config module wins over CLI)
    args.lr = cfg.lr
    args.decay_lrs = cfg.decay_lrs
    args.weight_decay = cfg.weight_decay
    args.momentum = cfg.momentum
    args.batch_size = cfg.batch_size

    print("Called with args:")
    print(args)

    # initial tensorboardX writer
    if args.use_tfboard:
        if args.exp_name == "default":
            writer = SummaryWriter()
        else:
            writer = SummaryWriter("runs/" + args.exp_name)

    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # load dataset
    print("loading dataset....")
    train_dataset = VOCDataset(args.dataset)
    print("dataset loaded.")

    print("training number: {}".format(len(train_dataset)))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=detection_collate,
        drop_last=True,
    )

    # initialize the model
    print("initialize the model")
    tic = time.time()
    model = Yolov2().to(args.device)
    toc = time.time()
    print("model loaded: cost time {:.2f}s".format(toc - tic))

    # loss function
    loss_fn = YoloLoss().to(args.device)

    # global learning rate (may be overwritten by a resumed checkpoint
    # or by the decay schedule)
    lr = args.lr

    # initialize the optimizer
    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # resume from the previous checkpoint
    if args.resume:
        ckp_path = get_checkpoint_file(output_dir, args.start_epoch - 1)
        if os.path.exists(ckp_path):
            print("resume from {}".format(ckp_path))
            # map_location makes a GPU-saved checkpoint loadable on any
            # device (e.g. resuming on a CPU-only machine)
            checkpoint = torch.load(ckp_path, map_location=args.device)
            model.load_state_dict(checkpoint["model"])

            lr = checkpoint["lr"]
            print("learning rate is {}".format(lr))
            adjust_learning_rate(optimizer, lr)

    # set the model mode to train because we have some layer, such as BN layer
    # whose behaviors are different when in training and testing.
    model.train()

    # start training
    for epoch in range(args.start_epoch, args.epochs + 1):
        # adjust super-parameters: step the learning rate on schedule
        if epoch in args.decay_lrs:
            lr = args.decay_lrs[epoch]
            adjust_learning_rate(optimizer, lr)
            print("adjust learning rate to {}".format(lr))

        # multi-scale training: switch the allowed scale range per epoch
        if cfg.multi_scale and epoch in cfg.epoch_scale:
            cfg.scale_range = cfg.epoch_scale[epoch]
            print("change scale range to {}".format(cfg.scale_range))

        loss_temp = 0
        tic = time.time()
        # floor division: drop_last=True means partial batches are discarded
        iters_per_epoch = len(train_dataset) // args.batch_size
        train_data_iter = iter(train_dataloader)
        for step in range(iters_per_epoch):
            # periodically pick a new random input resolution
            if cfg.multi_scale and (step + 1) % cfg.scale_step == 0:
                scale_index = np.random.randint(*cfg.scale_range)
                cfg.input_size = cfg.input_sizes[scale_index]
                print("change input size {}".format(cfg.input_size))

            # fetch a batch
            im_data, boxes, gt_classes, num_obj = next(train_data_iter)
            im_data = im_data.to(args.device)
            boxes = boxes.to(args.device)
            gt_classes = gt_classes.to(args.device)
            num_obj = num_obj.to(args.device)

            # forward and loss
            pred = model(im_data)
            box_loss, iou_loss, class_loss = loss_fn(pred, boxes, gt_classes, num_obj)
            loss = box_loss.mean() + iou_loss.mean() + class_loss.mean()

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log and print: loss_temp averages the total loss over the
            # display window; the per-component values are from the last step
            loss_temp += loss.item()
            if (step + 1) % args.display_interval == 0:
                toc = time.time()
                loss_temp /= args.display_interval

                iou_loss_v = iou_loss.mean().item()
                box_loss_v = box_loss.mean().item()
                class_loss_v = class_loss.mean().item()

                print(
                    "[epoch %2d][step %4d/%4d] loss: %.4f, lr: %.2e, time cost %.1fs "
                    "iou_loss: %.4f, box_loss: %.4f, cls_loss: %.4f"
                    % (
                        epoch,
                        step + 1,
                        iters_per_epoch,
                        loss_temp,
                        lr,
                        toc - tic,
                        iou_loss_v,
                        box_loss_v,
                        class_loss_v,
                    )
                )

                if args.use_tfboard:
                    n_iter = (epoch - 1) * iters_per_epoch + step + 1
                    writer.add_scalar("losses/loss", loss_temp, n_iter)
                    writer.add_scalar("losses/iou_loss", iou_loss_v, n_iter)
                    writer.add_scalar("losses/box_loss", box_loss_v, n_iter)
                    writer.add_scalar("losses/cls_loss", class_loss_v, n_iter)

                loss_temp = 0
                tic = time.time()

        # save checkpoint (lr is stored so a resume can restore the schedule)
        if epoch % args.save_interval == 0:
            save_name = get_checkpoint_file(output_dir, epoch)
            torch.save(
                {
                    "model": model.state_dict(),
                    "epoch": epoch,
                    "lr": lr,
                },
                save_name,
            )

    # flush any buffered tensorboard events before returning
    if args.use_tfboard:
        writer.close()


def get_checkpoint_file(output_dir, epoch):
    """Return the checkpoint path for *epoch* under *output_dir*."""
    filename = "yolov2_epoch_{}.pth".format(epoch)
    return os.path.join(output_dir, filename)


def parse_args():
    """
    Parse input arguments.

    Returns:
        argparse.Namespace with the training options.
    """

    def str2bool(value):
        # BUG FIX: the original used ``type=bool``, but argparse feeds the
        # raw string to the type callable and bool("False") is True (any
        # non-empty string is truthy), so ``--resume False`` could never
        # actually disable resuming.  Parse the text explicitly instead.
        if isinstance(value, bool):
            return value
        lowered = value.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "boolean value expected, got {!r}".format(value)
        )

    parser = argparse.ArgumentParser(description="Yolo v2")

    # fmt: off
    parser.add_argument('--dataset', default='data/train.txt', type=str)
    parser.add_argument('--device', default='cuda:0', type=str)
    parser.add_argument('--epochs', default=160, type=int, help='number of epochs to train')
    parser.add_argument('--num-workers', default=8, type=int, help='number of workers to load training data')
    parser.add_argument('--output-dir', default='output', type=str)
    parser.add_argument('--display-interval', default=10, type=int)
    parser.add_argument('--save-interval', default=20, type=int)
    parser.add_argument('--start-epoch', default=1, type=int)
    parser.add_argument('--resume', default=True, type=str2bool)
    parser.add_argument('--use-tfboard', default=False, type=str2bool)
    parser.add_argument('--exp-name', default='default', type=str)
    # fmt: on

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    train(args)
